'''
Author: Zhuangyu
Date: 2022-10-21 15:32:15
LastEditTime: 2023-05-24 00:18:12
'''
from __future__ import print_function, division
import time
#Importing relevant modules and packages
import numpy as np
import matplotlib.pyplot as plt
import random
from Parameters import *
from scipy import integrate, linalg, interpolate
from tqdm import tqdm
from casadi import *
import mpctools as mpc
import casadi
import scipy.io as sio  
import van_Genuchten as vg
from RichardsModel_1D import *
from prejacobian import F1,F2,F3,Fy
import scipy.linalg.interpolative as sli

#Space and time parameters
# Grid/time parameters and soil properties come from the star-imported project
# modules (Parameters / RichardsModel_1D); their meaning is defined there.
# NOTE(review): `Np` appears twice in this unpacking, so the 7th value returned
# by circular_parameters() is overwritten by the 9th -- confirm which value Np
# is meant to hold and rename the other target.
Nr,Nt,Nz,dr,dt,dz,Np,Na,Np,N_obs,Hz=circular_parameters()
DeltaT,Nsim=time_parameters()
p=Loam() #Soil parameters

def RANK(M):
    """Return the numerical rank of matrix ``M``.

    Thin wrapper around ``scipy.linalg.interpolative.estimate_rank`` with a
    fixed relative precision of 1e-8, so every rank test in this module uses
    the same tolerance.
    """
    return sli.estimate_rank(M, eps=1e-8)

# Number of propagation steps used by ode_sensitivity_States: the trajectory is
# stepped Nsim-1 times, so its per-step buffers hold `size` or `size+1` entries.
size=Nsim-1
# Scale factor for a noise-based sensitivity termination threshold
# (alpha*sqrt(sigma_w**2 + sigma_v**2)); that termination test is not active.
alpha = 1

def ode_sensitivity_States(x,y,u,ui,sigma_w,sigma_v,rank_xp):
    """Rank the states by output sensitivity and select an identifiable subset.

    State sensitivities are propagated as ``Sxx[i+1] = dFdx[i] @ Sxx[i]`` from
    ``Sxx[0] = I`` (presumably the sensitivity of x_i to the initial state --
    TODO confirm against F1), where ``dFdx[i]`` comes from ``F1``. They are
    mapped to the outputs with ``dHdx[i]`` from ``Fy`` and stacked over time
    into ``Sall``.

    When ``rank_xp`` is true, states are then selected greedily. At each step,
    the column of the current residual with the largest 2-norm is chosen. The
    residual is then recomputed as ``Sall`` minus its projection onto the span
    of all columns chosen so far.

    Parameters
    ----------
    x, u, ui : np.ndarray
        Trajectories indexed by time step along axis 0. Rows ``0..size-1`` are
        passed to ``F1``/``Fy`` together with ``DeltaT``.
    y : np.ndarray
        Unused; kept for call compatibility.
    sigma_w, sigma_v : float
        Unused; kept for call compatibility.
    rank_xp : bool
        Run the variable-selection procedure when true. Otherwise the
        selection indices and the column norms stay zero.

    Returns
    -------
    np.ndarray
        Flat float array of length ``Nz + 3``: ``[rank, Flag[0..Nz-1], rank_S, Js]``.

        - ``rank_S`` is the numerical rank of ``Sall``.
        - ``rank`` is the number of selected states (equal to ``rank_S`` unless
          the selection stopped early).
        - ``Flag`` holds the selected column indices in selection order
          (unused trailing entries are 0).
        - ``Js = rank * ||maxcol||_2``, where ``maxcol[i]`` is the largest
          residual column norm at selection step ``i``.
    """
    # ---- Sensitivity propagation ------------------------------------------
    Sxx = np.zeros([size+1,Nz,Nz])
    Sxx[0,:,:] = np.eye(Nz)
    Syx = np.zeros([size+1,N_obs,Nz])
    dFdx = np.zeros([size,Nz,Nz])
    dHdx = np.zeros([size+1,N_obs,Nz])

    for i in range(size):
        dFdx[i,:,:] = F1(x[i,:],u[i,:],ui[i,:],DeltaT)
        dHdx[i,:,:] = Fy(x[i,:],u[i,:],ui[i,:],DeltaT)
        Sxx[i+1,:,:] = mtimes(dFdx[i,:,:],Sxx[i,:,:])
        Syx[i,:,:] = mtimes(dHdx[i,:,:],Sxx[i,:,:])
    # NOTE(review): Syx[size] and dHdx[size] are allocated but never filled,
    # so Sxx[size] is unused and the last N_obs rows of Sall are zero. Zero
    # rows change neither the rank, the column norms nor the projections, so
    # the final time step simply does not contribute. Confirm whether it
    # should be Fy(x[size], ...) @ Sxx[size]; x/u/ui may need size+1 rows.

    # Stack the per-step blocks vertically: rows [i*N_obs:(i+1)*N_obs] = Syx[i].
    Sall = Syx.reshape((size+1)*N_obs, Nz)

    rank = RANK(Sall)
    rank_S = rank

    # ---- Variable selection procedure -------------------------------------
    # Size every buffer from Sall itself so the shapes always agree with it.
    n_rows = Sall.shape[0]
    Rl = np.zeros([Nz+1,n_rows,Nz])   # residual of Sall before each step
    Zl = np.zeros([Nz,n_rows,Nz])     # projection of Sall onto selected cols
    SumCol = np.zeros([Nz,Nz])        # column 2-norms of each residual
    Flag = np.zeros([Nz,1],dtype=int) # selected column indices, in order
    Xl = np.zeros([n_rows,Nz])        # selected columns of Sall
    Rl[0,:,:] = Sall

    if rank_xp:
        for i in range(rank):
            # Pick the residual column with the largest norm; np.argmax keeps
            # the first index on ties, like the original strict-< scan.
            SumCol[i,:] = np.sqrt(np.sum(Rl[i]**2, axis=0))
            Flag[i,0] = np.argmax(SumCol[i,:])
            Xl[:,i] = Sall[:,Flag[i,0]]

            # Enough columns selected: stop.
            if i == rank-1:
                break

            X = Xl[:,0:i+1]
            G = mpc.mtimes(X.T,X)  # Gram matrix, computed once and reused
            if RANK(G) != i+1:
                # Selected columns are numerically dependent: stop here.
                rank = i+1
                break

            # Residual = Sall minus its projection onto span(X).
            Zl[i,:,:] = mpc.mtimes(mpc.mtimes(mpc.mtimes(X,np.linalg.inv(G)),X.T),Sall)
            Rl[i+1,:,:] = Sall-Zl[i,:,:]
            # The selected columns should be exactly zero in the residual;
            # clear round-off leftovers so they are never picked again.
            Rl[i+1][:,Flag[0:i+1,0]] = 0

    maxcol = np.amax(SumCol,axis=1)
    print('maxcol =',maxcol)
    Js = rank*np.linalg.norm(maxcol)
    Result = np.concatenate((rank,Flag,rank_S,Js),axis = None)
    return Result